{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp collab\n",
    "#default_class_lvl 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.tabular.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Collaborative filtering\n",
    "\n",
    "> Tools to quickly get the data and train models suitable for collaborative filtering"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This module contains all the high-level functions you need in a collaborative filtering application to assemble your data, get a model and train it with a `Learner`. We will go other those in order but you can also check the [collaborative filtering tutorial](http://docs.fast.ai/tutorial.collab)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gather the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TabularCollab(TabularPandas):\n",
    "    \"Instance of `TabularPandas` suitable for collaborative filtering (with no continuous variable)\"\n",
    "    with_cont=False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is just to use the internal of the tabular application, don't worry about it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class CollabDataLoaders(DataLoaders):\n",
    "    \"Base `DataLoaders` for collaborative filtering.\"\n",
    "    @delegates(DataLoaders.from_dblock)\n",
    "    @classmethod\n",
    "    def from_df(cls, ratings, valid_pct=0.2, user_name=None, item_name=None, rating_name=None, seed=None, path='.', **kwargs):\n",
    "        \"Create a `DataLoaders` suitable for collaborative filtering from `ratings`.\"\n",
    "        user_name   = ifnone(user_name,   ratings.columns[0])\n",
    "        item_name   = ifnone(item_name,   ratings.columns[1])\n",
    "        rating_name = ifnone(rating_name, ratings.columns[2])\n",
    "        cat_names = [user_name,item_name]\n",
    "        splits = RandomSplitter(valid_pct=valid_pct, seed=seed)(range_of(ratings))\n",
    "        to = TabularCollab(ratings, [Categorify], cat_names, y_names=[rating_name], y_block=TransformBlock(), splits=splits)\n",
    "        return to.dataloaders(path=path, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def from_csv(cls, csv, **kwargs):\n",
    "        \"Create a `DataLoaders` suitable for collaborative filtering from `csv`.\"\n",
    "        return cls.from_df(pd.read_csv(csv), **kwargs)\n",
    "\n",
    "CollabDataLoaders.from_csv = delegates(to=CollabDataLoaders.from_df)(CollabDataLoaders.from_csv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This class should not be used directly, one of the factory methods should be preferred instead. All those factory methods accept as arguments:\n",
    "\n",
    "- `valid_pct`: the random percentage of the dataset to set aside for validation (with an optional `seed`)\n",
    "- `user_name`: the name of the column containing the user (defaults to the first column)\n",
    "- `item_name`: the name of the column containing the item (defaults to the second column)\n",
    "- `rating_name`: the name of the column containing the rating (defaults to the third column)\n",
    "- `path`: the folder where to work\n",
    "- `bs`: the batch size\n",
    "- `val_bs`: the batch size for the validation `DataLoader` (defaults to `bs`)\n",
    "- `shuffle_train`: if we shuffle the training `DataLoader` or not\n",
    "- `device`: the PyTorch device to use (defaults to `default_device()`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"CollabDataLoaders.from_df\" class=\"doc_header\"><code>CollabDataLoaders.from_df</code><a href=\"__main__.py#L4\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>CollabDataLoaders.from_df</code>(**`ratings`**, **`valid_pct`**=*`0.2`*, **`user_name`**=*`None`*, **`item_name`**=*`None`*, **`rating_name`**=*`None`*, **`seed`**=*`None`*, **`path`**=*`'.'`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create a [`DataLoaders`](/data.core#DataLoaders) suitable for collaborative filtering from `ratings`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(CollabDataLoaders.from_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see how this works on an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>73</td>\n",
       "      <td>1097</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1255504951</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>561</td>\n",
       "      <td>924</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1172695223</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>157</td>\n",
       "      <td>260</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1291598691</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>358</td>\n",
       "      <td>1210</td>\n",
       "      <td>5.0</td>\n",
       "      <td>957481884</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>130</td>\n",
       "      <td>316</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1138999234</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   userId  movieId  rating   timestamp\n",
       "0      73     1097     4.0  1255504951\n",
       "1     561      924     3.5  1172695223\n",
       "2     157      260     3.5  1291598691\n",
       "3     358     1210     5.0   957481884\n",
       "4     130      316     2.0  1138999234"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = untar_data(URLs.ML_SAMPLE)\n",
    "ratings = pd.read_csv(path/'ratings.csv')\n",
    "ratings.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>rating</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>157</td>\n",
       "      <td>1265</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>130</td>\n",
       "      <td>2858</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>481</td>\n",
       "      <td>1196</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>105</td>\n",
       "      <td>597</td>\n",
       "      <td>3.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>128</td>\n",
       "      <td>597</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>587</td>\n",
       "      <td>5952</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>111</td>\n",
       "      <td>2571</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>105</td>\n",
       "      <td>356</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>77</td>\n",
       "      <td>5349</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>119</td>\n",
       "      <td>1240</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dls = CollabDataLoaders.from_df(ratings, bs=64)\n",
    "dls.show_batch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"CollabDataLoaders.from_csv\" class=\"doc_header\"><code>CollabDataLoaders.from_csv</code><a href=\"__main__.py#L16\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>CollabDataLoaders.from_csv</code>(**`csv`**, **`valid_pct`**=*`0.2`*, **`user_name`**=*`None`*, **`item_name`**=*`None`*, **`rating_name`**=*`None`*, **`seed`**=*`None`*, **`path`**=*`'.'`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create a [`DataLoaders`](/data.core#DataLoaders) suitable for collaborative filtering from `csv`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(CollabDataLoaders.from_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = CollabDataLoaders.from_csv(path/'ratings.csv', bs=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "fastai provides two kinds of models for collaborative filtering: a dot-product model and a neural net. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class EmbeddingDotBias(Module):\n",
    "    \"Base dot model for collaborative filtering.\"\n",
    "    def __init__(self, n_factors, n_users, n_items, y_range=None):\n",
    "        self.y_range = y_range\n",
    "        (self.u_weight, self.i_weight, self.u_bias, self.i_bias) = [Embedding(*o) for o in [\n",
    "            (n_users, n_factors), (n_items, n_factors), (n_users,1), (n_items,1)\n",
    "        ]]\n",
    "\n",
    "    def forward(self, x):\n",
    "        users,items = x[:,0],x[:,1]\n",
    "        dot = self.u_weight(users)* self.i_weight(items)\n",
    "        res = dot.sum(1) + self.u_bias(users).squeeze() + self.i_bias(items).squeeze()\n",
    "        if self.y_range is None: return res\n",
    "        return torch.sigmoid(res) * (self.y_range[1]-self.y_range[0]) + self.y_range[0]\n",
    "\n",
    "    @classmethod\n",
    "    def from_classes(cls, n_factors, classes, user=None, item=None, y_range=None):\n",
    "        \"Build a model with `n_factors` by inferring `n_users` and  `n_items` from `classes`\"\n",
    "        if user is None: user = list(classes.keys())[0]\n",
    "        if item is None: item = list(classes.keys())[1]\n",
    "        res = cls(n_factors, len(classes[user]), len(classes[item]), y_range=y_range)\n",
    "        res.classes,res.user,res.item = classes,user,item\n",
    "        return res\n",
    "\n",
    "    def _get_idx(self, arr, is_item=True):\n",
    "        \"Fetch item or user (based on `is_item`) for all in `arr`\"\n",
    "        assert hasattr(self, 'classes'), \"Build your model with `EmbeddingDotBias.from_classes` to use this functionality.\"\n",
    "        classes = self.classes[self.item] if is_item else self.classes[self.user]\n",
    "        c2i = {v:k for k,v in enumerate(classes)}\n",
    "        try: return tensor([c2i[o] for o in arr])\n",
    "        except Exception as e:\n",
    "            print(f\"\"\"You're trying to access {'an item' if is_item else 'a user'} that isn't in the training data.\n",
    "                  If it was in your original data, it may have been split such that it's only in the validation set now.\"\"\")\n",
    "\n",
    "    def bias(self, arr, is_item=True):\n",
    "        \"Bias for item or user (based on `is_item`) for all in `arr`\"\n",
    "        idx = self._get_idx(arr, is_item)\n",
    "        layer = (self.i_bias if is_item else self.u_bias).eval().cpu()\n",
    "        return to_detach(layer(idx).squeeze(),gather=False)\n",
    "\n",
    "    def weight(self, arr, is_item=True):\n",
    "        \"Weight for item or user (based on `is_item`) for all in `arr`\"\n",
    "        idx = self._get_idx(arr, is_item)\n",
    "        layer = (self.i_weight if is_item else self.u_weight).eval().cpu()\n",
    "        return to_detach(layer(idx),gather=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is built with `n_factors` (the length of the internal vectors), `n_users` and `n_items`. For a given user and item, it grabs the corresponding weights and bias and returns\n",
    "``` python\n",
    "torch.dot(user_w, item_w) + user_b + item_b\n",
    "```\n",
    "Optionally, if `y_range` is passed, it applies a `SigmoidRange` to that result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,y = dls.one_batch()\n",
    "model = EmbeddingDotBias(50, len(dls.classes['userId']), len(dls.classes['movieId']), y_range=(0,5)\n",
    "                        ).to(x.device)\n",
    "out = model(x)\n",
    "assert (0 <= out).all() and (out <= 5).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"EmbeddingDotBias.from_classes\" class=\"doc_header\"><code>EmbeddingDotBias.from_classes</code><a href=\"__main__.py#L17\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>EmbeddingDotBias.from_classes</code>(**`n_factors`**, **`classes`**, **`user`**=*`None`*, **`item`**=*`None`*, **`y_range`**=*`None`*)\n",
       "\n",
       "Build a model with `n_factors` by inferring `n_users` and  `n_items` from `classes`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(EmbeddingDotBias.from_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`y_range` is passed to the main init. `user` and `item` are the names of the keys for users and items in `classes` (default to the first and second key respectively). `classes` is expected to be a dictionary key to list of categories like the result of `dls.classes` in a `CollabDataLoaders`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'userId': (#101) ['#na#',15,17,19,23,30,48,56,73,77...],\n",
       " 'movieId': (#101) ['#na#',1,10,32,34,39,47,50,110,150...]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dls.classes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see how it can be used in practice:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = EmbeddingDotBias.from_classes(50, dls.classes,  y_range=(0,5)\n",
    "                                     ).to(x.device)\n",
    "out = model(x)\n",
    "assert (0 <= out).all() and (out <= 5).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two convenience methods are added to easily access the weights and bias when a model is created with `EmbeddingDotBias.from_classes`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"EmbeddingDotBias.weight\" class=\"doc_header\"><code>EmbeddingDotBias.weight</code><a href=\"__main__.py#L42\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>EmbeddingDotBias.weight</code>(**`arr`**, **`is_item`**=*`True`*)\n",
       "\n",
       "Weight for item or user (based on `is_item`) for all in `arr`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(EmbeddingDotBias.weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The elements of `arr` are expected to be class names (which is why the model needs to be created with `EmbeddingDotBias.from_classes`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mov = dls.classes['movieId'][42] \n",
    "w = model.weight([mov])\n",
    "test_eq(w, model.i_weight(tensor([42])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"EmbeddingDotBias.bias\" class=\"doc_header\"><code>EmbeddingDotBias.bias</code><a href=\"__main__.py#L36\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>EmbeddingDotBias.bias</code>(**`arr`**, **`is_item`**=*`True`*)\n",
       "\n",
       "Bias for item or user (based on `is_item`) for all in `arr`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(EmbeddingDotBias.bias)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The elements of `arr` are expected to be class names (which is why the model needs to be created with `EmbeddingDotBias.from_classes`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mov = dls.classes['movieId'][42] \n",
    "b = model.bias([mov])\n",
    "test_eq(b, model.i_bias(tensor([42])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "class EmbeddingNN(TabularModel):\n",
    "    \"Subclass `TabularModel` to create a NN suitable for collaborative filtering.\"\n",
    "    @delegates(TabularModel.__init__)\n",
    "    def __init__(self, emb_szs, layers, **kwargs):\n",
    "        super().__init__(emb_szs=emb_szs, n_cont=0, out_sz=1, layers=layers, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h2 id=\"EmbeddingNN\" class=\"doc_header\"><code>class</code> <code>EmbeddingNN</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h2>\n",
       "\n",
       "> <code>EmbeddingNN</code>(**`emb_szs`**, **`layers`**, **`ps`**=*`None`*, **`embed_p`**=*`0.0`*, **`y_range`**=*`None`*, **`use_bn`**=*`True`*, **`bn_final`**=*`False`*, **`bn_cont`**=*`True`*) :: [`TabularModel`](/tabular.model#TabularModel)\n",
       "\n",
       "Subclass [`TabularModel`](/tabular.model#TabularModel) to create a NN suitable for collaborative filtering."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(EmbeddingNN)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`emb_szs` should be a list of two tuples, one for the users, one for the items, each tuple containing the number of users/items and the corresponding embedding size (the function `get_emb_sz` can give a good default). All the other arguments are passed to `TabularModel`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_szs = get_emb_sz(dls.train_ds, {})\n",
    "model = EmbeddingNN(emb_szs, [50], y_range=(0,5)\n",
    "                   ).to(x.device)\n",
    "out = model(x)\n",
    "assert (0 <= out).all() and (out <= 5).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a `Learner`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following function lets us quickly create a `Learner` for collaborative filtering from the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "@log_args(to_return=True, but_as=Learner.__init__)\n",
    "@delegates(Learner.__init__)\n",
    "def collab_learner(dls, n_factors=50, use_nn=False, emb_szs=None, layers=None, config=None, y_range=None, loss_func=None, **kwargs):\n",
    "    \"Create a Learner for collaborative filtering on `dls`.\"\n",
    "    emb_szs = get_emb_sz(dls, ifnone(emb_szs, {}))\n",
    "    if loss_func is None: loss_func = MSELossFlat()\n",
    "    if config is None: config = tabular_config()\n",
    "    if y_range is not None: config['y_range'] = y_range\n",
    "    if layers is None: layers = [n_factors]\n",
    "    if use_nn: model = EmbeddingNN(emb_szs=emb_szs, layers=layers, **config)\n",
    "    else:      model = EmbeddingDotBias.from_classes(n_factors, dls.classes, y_range=y_range)\n",
    "    return Learner(dls, model, loss_func=loss_func, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `use_nn=False`, the model used is an `EmbeddingDotBias` with `n_factors` and `y_range`. Otherwise, it's a `EmbeddingNN` for which you can pass `emb_szs` (will be inferred from the `dls` with `get_emb_sz` if you don't provide any), `layers` (defaults to `[n_factors]`) `y_range`, and a `config` that you can create with `tabular_config` to customize your model. \n",
    "\n",
    "`loss_func` will default to `MSELossFlat` and all the other arguments are passed to `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = collab_learner(dls, y_range=(0,5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.521979</td>\n",
       "      <td>2.541627</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import *\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
